from typing import Tuple

from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.adjacency_utils import get_A, get_B, get_XA


class ATTInputAdapter(nn.Module):
    """InputAdapter for ATT.

    Converts a PyG ``Batch`` into padded, batched dense tensors by delegating to
    ``get_XA``, ``get_A`` and ``get_B``. The module registers no parameters or
    submodules; its only state is the padding size ``n``.
    """

    def __init__(
        self,
        n: int,
    ) -> None:
        """Initialize the ATTInputAdapter.

        :param n: Maximum number of nodes in a single graph. It is passed to
            ``get_XA``, ``get_A`` and ``get_B`` as the size to pad to.
        """
        super().__init__()

        self.n = n

    def extra_repr(self) -> str:
        """Include the padding size in ``repr(module)``, like built-in modules do."""
        return f"n={self.n}"

    def forward(self, data: Batch) -> Tuple[
        Float[Tensor, "b n_leaves d"],
        Float[Tensor, "b n_leaves n_internal"],
        Float[Tensor, "b n_internal n_internal"],
    ]:
        """Compute the feature matrix XA and the batched adjacency matrices A and B.

        :param data: PyG batch object.
        :return: The tuple ``(XA, A, B)``, in that order. Each element has the
            shape given by the corresponding entry of the return annotation.
            The dimension names come from that annotation; their meaning is
            defined by the ``get_*`` helpers, which are not shown here.
        """
        # Each helper pads its output to ``self.n`` so graphs of different
        # sizes can be stacked along the batch dimension.
        XA = get_XA(data, self.n)
        A = get_A(data, self.n)
        B = get_B(data, self.n)
        # Order matters: callers unpack as ``XA, A, B``, matching the annotation.
        return XA, A, B
