from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.sparse_utils import index_subset_mask
from src.models.backbone.input_adapter.edge_features.index_selection import IndexSelection


class Backbone(nn.Module):
    """Backbone for PyG-based models.

    Chains three user-supplied modules: an input adapter that turns a batch into
    ``(x, edge_index, edge_attr)``, a network that runs on a selected subset of
    those edges, and an output adapter that produces the final per-edge output.
    """

    def __init__(
        self,
        net: nn.Module,
        input_adapter: nn.Module,
        output_adapter: nn.Module,
        n: int,
        k: int,
    ) -> None:
        """Initialize backbone.

        :param net: PyG model, called as ``net(x, edge_index, edge_attr)``.
        :param input_adapter: Module called on the batch; must return
            ``(x, edge_index, edge_attr)``.
        :param output_adapter: Module called as
            ``output_adapter(x, edge_index, edge_attr)`` to produce the result.
        :param n: First positional argument forwarded to ``IndexSelection``.
        :param k: Second positional argument forwarded to ``IndexSelection``.
        """
        super().__init__()

        # Keep this assignment order: nn.Module yields parameters and
        # state_dict entries in the order submodules are registered.
        self.net = net
        self.input_adapter = input_adapter
        self.output_adapter = output_adapter
        self.index_selector = IndexSelection(n, k)

    def forward(self, data: Batch) -> Float[Tensor, "n_edges"]:
        """Perform a forward pass with the backbone.

        :param data: PyG batch object. Its ``edge_index`` is handed to the
            output adapter as the edges to predict on.
        :return: Output of the output adapter; per the annotation, one value per
            edge (presumably the edges of ``data.edge_index``).
        """
        node_feats, adapter_edge_index, adapter_edge_attr = self.input_adapter(data)

        # Message passing only runs over the edges chosen by the index selector.
        keep = self.index_selector(data, adapter_edge_index)
        node_feats = self.net(
            node_feats,
            adapter_edge_index[:, keep],
            adapter_edge_attr[keep],
        )

        # Narrow the adapter's edge attributes down to the batch's own edges.
        # Presumably index_subset_mask marks the columns of adapter_edge_index
        # that also occur in data.edge_index — TODO confirm helper semantics.
        # NOTE(review): the masked rows are paired with data.edge_index below,
        # which assumes both enumerate those edges in the same order.
        target_edge_index = data.edge_index
        in_target = index_subset_mask(adapter_edge_index, target_edge_index)
        return self.output_adapter(node_feats, target_edge_index, adapter_edge_attr[in_target])
