import torch
from jaxtyping import Float, Int
from torch import Tensor, nn

from src.utils.backbone_utils import make_nn


class NetHead(nn.Module):
    """NetHead: Scores each edge in the graph using a neural network.

    For every edge, the features of its two endpoint nodes are concatenated with the edge's own
    features and fed through a network built by ``make_nn`` with a single output unit.

    NOTE(review): whether the output is a probability or a raw logit depends on the final layer
    produced by ``make_nn``, which is not visible here -- confirm before applying a loss.
    """

    def __init__(
        self,
        n_layers: int,
        d_hidden: int,
        d_edge: int,
    ) -> None:
        """Initialize the NetHead.

        :param n_layers: Number of layers, forwarded to ``make_nn``.
        :param d_hidden: Number of features per node; also used as the hidden width of the
            network.
        :param d_edge: Number of features per edge.
        """
        super().__init__()

        self.d_edge = d_edge
        # Input per edge is [source node features, target node features, edge features].
        self.net = make_nn(
            n_layers=n_layers, d_input=2 * d_hidden + d_edge, d_hidden=d_hidden, d_output=1
        )

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_hidden"],
        edge_index: Int[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges d_edge"],
    ) -> Float[Tensor, "n_edges"]:
        """Compute one score per edge from its endpoint node features and edge features.

        :param x: Node features.
        :param edge_index: Integer indices of the start (row 0) and end (row 1) node of each edge.
        :param edge_attr: Edge features. A 1-D tensor of shape ``(n_edges,)`` is also accepted and
            is treated as a single feature per edge.
        :return: One value per edge, shape ``(n_edges,)``.
        """
        start, end = edge_index
        # Scalar edge features may arrive flattened; restore the feature axis so they can be
        # concatenated with the 2-D node features. Checking the actual rank (rather than
        # ``self.d_edge``) keeps an already 2-D ``(n_edges, 1)`` input intact.
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(1)
        nn_input = torch.cat([x[start], x[end], edge_attr], dim=1)
        # Drop only the trailing output axis: a bare ``squeeze()`` would also remove the edge
        # axis when the graph has exactly one edge, returning a 0-dim tensor.
        return self.net(nn_input).squeeze(-1)
