from jaxtyping import Float, Int
from torch import Tensor, nn

from src.utils.gnn_utils import make_gnn


class VanillaGNN(nn.Module):
    """Simple graph neural network.

    Thin wrapper around the layer stack built by ``make_gnn``; the forward
    pass delegates directly to that module.
    """

    def __init__(self, conv: str, n_layer: int, d_hidden: int, d_edge: int) -> None:
        """Initialize the vanilla graph neural network.

        :param conv: Name of the convolutional layer to use; forwarded to
            ``make_gnn``.
        :param n_layer: Number of layers.
        :param d_hidden: Number of hidden channels.
        :param d_edge: Number of edge features. Currently unused: it is not
            forwarded to ``make_gnn``. Presumably kept so the constructor
            signature matches other GNN variants (TODO confirm).
        """
        super().__init__()

        self.gnn = make_gnn(conv, n_layer, d_hidden)

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_hidden"],
        edge_index: Int[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges"],
    ) -> Float[Tensor, "n_nodes d_hidden"]:
        """Perform a forward pass with the graph neural network.

        :param x: Node features.
        :param edge_index: Edge connectivity with shape ``(2, n_edges)``.
            Annotated as an integer tensor because it holds node indices
            (PyG-style COO layout assumed; verify against callers).
        :param edge_attr: Edge features.
        :return: Node features produced by the wrapped ``self.gnn`` module.
        """
        return self.gnn(x, edge_index, edge_attr)
