import torch
from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.adjacency_utils import get_adj_indices
from src.utils.sparse_utils import sparse_softmax


class ATTOutputAdapter(nn.Module):
    """OutputAdapter for ATT.

    Turns the two dense adjacency predictions produced by the model into a flat
    vector of per-edge scores aligned with the edges of the PyG batch.
    """

    def __init__(
        self,
        n: int,
    ) -> None:
        """Initialize the ATTOutputAdapter.

        :param n: Maximal number of nodes in one graph. It is forwarded to
            ``get_adj_indices`` and is needed for padding.
        """
        super().__init__()

        # Stored once; consumed by ``get_adj_indices`` on every forward pass.
        self.n = n

    def forward(
        self,
        data: Batch,
        A_pred: Float[Tensor, "b n_leaves n_internal"],
        B_pred: Float[Tensor, "b n_internal n_internal"],
    ) -> Float[Tensor, "n_edges"]:
        """Computes edge probabilities from the output of the GNN.

        :param data: PyG batch object; its ``edge_index`` drives the final
            normalization.
        :param A_pred: Predicted adjacency matrix A.
        :param B_pred: Predicted adjacency matrix B.
        :return: Predicted edge probabilities, one entry per gathered edge.
        """
        # Stack A on top of B along the row axis. Both share the trailing
        # ``n_internal`` dimension, so the result is
        # (b, n_leaves + n_internal, n_internal), with leaf rows first.
        stacked = torch.cat((A_pred, B_pred), dim=1)

        # One (batch, row, column) index triple per edge. The exact ordering
        # comes from ``get_adj_indices``. NOTE(review): presumably it matches
        # the order of ``data.edge_index``; confirm against the helper.
        batch_idx, row_idx, col_idx = get_adj_indices(data, self.n)
        edge_logits = stacked[batch_idx, row_idx, col_idx]

        # Normalize the gathered logits with a softmax over groups defined by
        # ``data.edge_index`` (axis=1). The grouping semantics live in
        # ``sparse_softmax``; verify there.
        return sparse_softmax(data.edge_index, edge_logits, axis=1)
