from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.sparse_utils import index_subset_mask


class Backbone(nn.Module):
    """Backbone for PyG based models.

    Chains three sub-modules: ``input_adapter`` converts a PyG batch into
    ``(x, edge_index, edge_attr)``, ``net`` is applied to that triple and its
    output replaces ``x``, and ``output_adapter`` turns the result into one
    prediction per edge of the original batch.
    """

    def __init__(
        self, net: nn.Module, input_adapter: nn.Module, output_adapter: nn.Module
    ) -> None:
        """Initialize backbone.

        :param net: PyG model, called as ``net(x, edge_index, edge_attr)``.
        :param input_adapter: Input adapter, called with the batch; must return
            a ``(x, edge_index, edge_attr)`` triple.
        :param output_adapter: Output adapter, called as
            ``output_adapter(x, edge_index, edge_attr)`` with the batch's
            original ``edge_index``.
        """
        super().__init__()

        self.net = net
        self.input_adapter = input_adapter
        self.output_adapter = output_adapter

    def forward(self, data: Batch) -> Float[Tensor, "n_edges"]:
        """Perform a forward pass with the backbone.

        :param data: PyG batch object.
        :return: Prediction for each edge in ``data.edge_index``.
        """
        x, edge_index, edge_attr = self.input_adapter(data)

        x = self.net(x, edge_index, edge_attr)

        # The input adapter may return more edges than the batch contains.
        # The output adapter is only given the batch's own edges, so edge
        # attributes must be cut back to that subset to stay aligned.
        # NOTE(review): assumes index_subset_mask returns a boolean mask over
        # ``edge_index`` whose selected edges appear in the same order as in
        # ``data.edge_index`` — confirm.
        # Edge attributes are optional, so there is nothing to subset when the
        # adapter returns None.
        if edge_attr is not None and edge_index.size(1) > data.edge_index.size(1):
            subset_mask = index_subset_mask(edge_index, data.edge_index)
            edge_attr = edge_attr[subset_mask]
        return self.output_adapter(x, data.edge_index, edge_attr)
