import torch
from jaxtyping import Float, Int
from torch import Tensor, nn

from src.utils.backbone_utils import make_nn


class DotHead(nn.Module):
    """DotHead: Scores each edge from its features and the dot product of its endpoint node features.

    For every edge ``(start, end)`` the scalar ``<x[start], x[end]>`` is appended to the edge's
    features, and the resulting ``d_edge + 1`` features are fed to a network built by ``make_nn``
    that produces one output per edge.
    """

    def __init__(
        self,
        n_layers: int,
        d_edge: int,
    ) -> None:
        """Initialize the DotHead.

        :param n_layers: Number of layers of the network built by ``make_nn``.
        :param d_edge: Number of features per edge. The network takes ``d_edge + 1`` input
            features (edge features plus the dot product), uses ``d_edge + 1`` hidden features,
            and has a single output.
        """
        super().__init__()

        self.d_edge = d_edge
        self.net = make_nn(n_layers=n_layers, d_input=d_edge + 1, d_hidden=d_edge + 1, d_output=1)

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_node"],
        edge_index: Int[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges d_edge"],
    ) -> Tensor:
        """Score every edge from its features and the dot product of its endpoint node features.

        :param x: Node features, one row per node.
        :param edge_index: Integer tensor used to index ``x``; row 0 holds the start node and
            row 1 the end node of each edge.
        :param edge_attr: Edge features of shape ``(n_edges, d_edge)``. A 1-D tensor of shape
            ``(n_edges,)`` is also accepted and treated as a single feature per edge.
        :return: One value per edge, shape ``(n_edges,)`` (assuming ``self.net`` returns
            ``(n_edges, 1)``). NOTE(review): whether this is a probability or a logit depends on
            the final layer produced by ``make_nn`` — confirm.
        """
        start, end = edge_index
        # One scalar per edge: dot product of the two endpoint feature vectors, kept 2-D for cat.
        dot_product = (x[start] * x[end]).sum(dim=1, keepdim=True)
        # Add the feature axis only when it is actually missing. Checking the rank (instead of
        # self.d_edge == 1) also accepts an already 2-D (n_edges, 1) input, which would otherwise
        # become 3-D and break the concatenation below.
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)
        nn_input = torch.cat([edge_attr, dot_product], dim=1)
        # Squeeze only the size-1 output axis: a bare .squeeze() would also drop the edge axis
        # when there is exactly one edge, returning a 0-d tensor instead of shape (1,).
        return self.net(nn_input).squeeze(-1)
